// Copyright 2024 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package botapi

import (
	"context"
	"fmt"
	"net/url"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"go.chromium.org/luci/auth/identity"
	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/gae/service/datastore"
	"go.chromium.org/luci/server/auth"

	configpb "go.chromium.org/luci/swarming/proto/config"
	"go.chromium.org/luci/swarming/server/botinfo"
	"go.chromium.org/luci/swarming/server/botsession"
	"go.chromium.org/luci/swarming/server/botsrv"
	"go.chromium.org/luci/swarming/server/botstate"
	"go.chromium.org/luci/swarming/server/model"
	"go.chromium.org/luci/swarming/server/resultdb"
	"go.chromium.org/luci/swarming/server/tasks"
)

// ClaimCommand instructs the bot what to do after it calls /bot/claim.
//
// It is returned to the bot in the `cmd` field of ClaimResponse.
type ClaimCommand string

const (
	// ClaimSkip means the bot should skip this task and poll for another one.
	//
	// The response carries a human-readable Reason explaining why.
	ClaimSkip ClaimCommand = "skip"
	// ClaimRun means the bot should start running the task.
	//
	// The response carries the TaskManifest describing the task.
	ClaimRun ClaimCommand = "run"
	// ClaimTerminate means the bot process should gracefully terminate.
	//
	// Returned when the claimed task request is a termination request.
	ClaimTerminate ClaimCommand = "terminate"
)

// ClaimRequest is the JSON body of the /bot/claim request sent by the bot.
type ClaimRequest struct {
	// Session is a serialized Swarming Bot Session proto.
	Session []byte `json:"session"`

	// State is (mostly) arbitrary JSON dict with various properties of the bot.
	//
	// This field is used by the bot to opportunistically report its state changes
	// when running tasks back-to-back (since it doesn't call /bot/poll in that
	// case and has no other way to report the state).
	//
	// A state that fails to unseal is logged and ignored by Claim; it does not
	// fail the request.
	//
	// Optional.
	State botstate.Dict `json:"state,omitempty"`

	// ClaimID is an opaque string used to make this request idempotent.
	//
	// Generated by the bot (usually derived from RBE's lease ID, but it is not
	// a requirement). The server prefixes it with the bot ID before using it as
	// the claim ID, so two bots can not collide in their claim IDs.
	//
	// Required. Claim rejects the request if it is empty.
	ClaimID string `json:"claim_id"`

	// TaskID is the TaskResultSummary packed key of the task to claim.
	//
	// The bot takes it from the RBE task payload and passes it to /bot/claim as
	// is, see swarming.internals.rbe.TaskPayload proto message.
	//
	// Required. Claim rejects the request if it is empty or malformed.
	TaskID string `json:"task_id"`

	// TaskToRunShard is an entity class shard index for claimed TaskToRun.
	//
	// The bot takes it from the RBE task payload and passes it to /bot/claim as
	// is, see swarming.internals.rbe.TaskPayload proto message.
	//
	// Required. Note that Claim does not reject a zero value for this field.
	TaskToRunShard int32 `json:"task_to_run_shard"`

	// TaskToRunID identifies TaskToRun to claim.
	//
	// The bot takes it from the RBE task payload and passes it to /bot/claim as
	// is, see swarming.internals.rbe.TaskPayload proto message.
	//
	// Required. Claim rejects the request if it is zero.
	TaskToRunID int64 `json:"task_to_run_id"`
}

// ExtractSession returns the serialized bot session carried by the request.
func (r *ClaimRequest) ExtractSession() []byte {
	return r.Session
}
// ExtractDebugRequest returns a copy of the request with the session blob
// removed.
//
// All other fields are carried over unchanged.
func (r *ClaimRequest) ExtractDebugRequest() any {
	// Shallow-copy every field, then drop the serialized session.
	scrubbed := *r
	scrubbed.Session = nil
	return &scrubbed
}

// ClaimResponse is the JSON body returned by the server in response to
// /bot/claim.
type ClaimResponse struct {
	// Cmd instructs the bot what to do next.
	Cmd ClaimCommand `json:"cmd"`

	// Session is a serialized bot session proto.
	//
	// If not empty, contains the refreshed session. It is populated only
	// together with the ClaimRun command.
	Session []byte `json:"session,omitempty"`

	// Manifest is full details about the task to execute.
	//
	// Present only for ClaimRun command.
	Manifest *TaskManifest `json:"manifest,omitempty"`

	// TaskID is TaskRunResult ID of the task to run (if any).
	//
	// This is populated for both ClaimRun and ClaimTerminate.
	TaskID string `json:"task_id,omitempty"`

	// Reason is why a task was skipped.
	//
	// Present only for ClaimSkip command.
	Reason string `json:"reason,omitempty"`
}

// TaskManifest describes to the bot what it should execute.
//
// This is passed to the task runner process.
//
// Most task-related fields are copied from the properties of the task slice
// being claimed; Bot* fields describe the bot that made the claim.
type TaskManifest struct {
	// TaskID is TaskRunResult ID of the task to run.
	TaskID string `json:"task_id"`

	// Caches is a list of named caches requested by the task.
	Caches []TaskCache `json:"caches,omitempty"`
	// CIPDInput are CIPD packages that the bot should fetch.
	//
	// Omitted if the task has no CIPD input.
	CIPDInput *model.CIPDInput `json:"cipd_input,omitempty"`
	// Command is the actual command line the bot should execute.
	Command []string `json:"command"`
	// Containment describes the task process containment (not implemented).
	//
	// Omitted if the containment type is unset (zero).
	Containment *model.Containment `json:"containment,omitempty"`
	// Dimensions are task's dimension requirements.
	Dimensions model.TaskDimensions `json:"dimensions"`
	// Env is the environment variables to set.
	Env model.Env `json:"env,omitempty"`
	// EnvPrefixes is values to prepend to environment variables.
	EnvPrefixes model.EnvPrefixes `json:"env_prefixes,omitempty"`
	// GracePeriodSecs is how long to wait for the task to gracefully terminate.
	GracePeriodSecs int64 `json:"grace_period"`
	// HardTimeoutSecs is how long to allow the task to run.
	//
	// Populated from the slice's ExecutionTimeoutSecs.
	HardTimeoutSecs int64 `json:"hard_timeout"`
	// IOTimeoutSecs is how long to allow no stdout before the task is terminated.
	IOTimeoutSecs int64 `json:"io_timeout"`
	// SecretBytes are the task secret, if any.
	SecretBytes []byte `json:"secret_bytes,omitempty"`
	// CASInputRoot is what CAS files to fetch.
	//
	// Omitted if the task has no CAS instance set.
	CASInputRoot *model.CASReference `json:"cas_input_root,omitempty"`
	// Outputs are list of extra outputs to upload to RBE-CAS as task results.
	Outputs []string `json:"outputs,omitempty"`
	// Realm is the task's security realm.
	//
	// Omitted if the task request has no realm.
	Realm *TaskRealm `json:"realm,omitempty"`
	// RelativeCwd is the directory to set as current when running the task.
	RelativeCwd string `json:"relative_cwd,omitempty"`
	// ResultDB is parameters of the ResultDB invocation for the task.
	//
	// Omitted if no ResultDB hostname could be determined for the task.
	ResultDB *TaskResultDB `json:"resultdb,omitempty"`
	// ServiceAccounts describe what service account a task can use.
	ServiceAccounts TaskServiceAccounts `json:"service_accounts"`

	// BotID is the ID of the calling bot.
	BotID string `json:"bot_id"`
	// BotDimensions are dimensions of the calling bot.
	BotDimensions map[string][]string `json:"bot_dimensions"`
	// BotAuthenticatedAs is how the server authenticated the bot.
	BotAuthenticatedAs identity.Identity `json:"bot_authenticated_as"`
}

// TaskRealm defines the task's security realm.
//
// It is put into LUCI_CONTEXT["realm"].
type TaskRealm struct {
	// Name is the name of the realm, as recorded in the task request.
	Name string `json:"name,omitempty"`
}

// TaskResultDB is parameters of the ResultDB invocation for the task.
//
// It is put into LUCI_CONTEXT["resultdb"].
type TaskResultDB struct {
	// Hostname is the hostname of ResultDB service to use.
	//
	// It is the host part of the ResultDB server URL from the service settings.
	Hostname string `json:"hostname"`
	// CurrentInvocation is the ResultDB invocation to use.
	CurrentInvocation TaskInvocation `json:"current_invocation"`
}

// TaskInvocation is task's ResultDB invocation info.
type TaskInvocation struct {
	// Name is the invocation name.
	//
	// Produced by resultdb.InvocationName for the task being claimed.
	Name string `json:"name"`
	// UpdateToken is the invocation's update token.
	//
	// Copied from the ResultDBUpdateToken stored in the task request.
	UpdateToken string `json:"update_token"`
}

// TaskCache describes one named cache requested by a task.
type TaskCache struct {
	// Name is a logical cache name.
	Name string `json:"name"`
	// Path is where to mount it relative to the task root directory.
	Path string `json:"path"`
	// Hint is cache size hint (as a decimal string) or "-1" if unknown.
	//
	// The value is the decimal formatting of the hint returned by
	// model.FetchNamedCacheSizeHints for this cache.
	Hint string `json:"hint"`
}

// TaskServiceAccounts are service accounts accessible to the task.
type TaskServiceAccounts struct {
	// System describes "system" logical account.
	//
	// Populated from the system service account in the bot's session config.
	System TaskServiceAccount `json:"system"`
	// Task describes "task" logical account.
	//
	// Populated from the service account in the task request.
	Task TaskServiceAccount `json:"task"`
}

// TaskServiceAccount describes what service account a task can use.
type TaskServiceAccount struct {
	// ServiceAccount is 'none', 'bot' or an email.
	//
	// Bot interprets 'none' and 'bot' locally. When it sees something else, it
	// uses /bot/oauth_token or /bot/id_token API endpoints to grab tokens through
	// the server.
	//
	// The server substitutes 'none' when no account is configured.
	ServiceAccount string `json:"service_account"`
}

// Claim implements the handler that claims pending tasks.
//
// This transactionally and idempotently assigns the task to the calling bot or
// instructs the bot to skip the task if it is no longer pending. On success it
// returns detailed description of the task to let the bot know what it should
// be executing.
//
// Called by bots after they get a lease from the RBE.
//
// Outcomes:
//   - ClaimResponse with ClaimSkip if the task can't be claimed by this bot:
//     it is missing, its slice has expired, it is claimed under a different
//     claim ID or its dimensions do not match the bot.
//   - ClaimResponse with ClaimRun or ClaimTerminate if the task is now (or
//     already was, in case of a retry) claimed by this bot.
//   - InvalidArgument gRPC error if a required request field is missing or the
//     task ID is malformed.
//   - Internal gRPC error if fetching entities or submitting the claim fails.
func (srv *BotAPIServer) Claim(ctx context.Context, body *ClaimRequest, r *botsrv.Request) (botsrv.Response, error) {
	if body.ClaimID == "" {
		return nil, status.Errorf(codes.InvalidArgument, "missing claim ID")
	}
	if body.TaskID == "" {
		return nil, status.Errorf(codes.InvalidArgument, "missing task ID")
	}
	if body.TaskToRunID == 0 {
		return nil, status.Errorf(codes.InvalidArgument, "missing task to run ID")
	}

	// Ignore invalid state. This broken state will be noticed in the next Poll
	// call and move the bot into quarantined state. Here we already have a
	// pending task which the bot must execute. Better to try to execute the task
	// even if something is wrong than just totally ignore it.
	var state *botstate.Dict
	if !body.State.IsEmpty() {
		if err := body.State.Unseal(); err != nil {
			logging.Errorf(ctx, "Ignoring bad state: %s", err)
		} else {
			state = &body.State
		}
	}

	// Fetch the requested TaskToRun.
	reqKey, err := model.TaskIDToRequestKey(ctx, body.TaskID)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "%s", err)
	}
	ttr := &model.TaskToRun{
		Key: model.TaskToRunKey(ctx, reqKey, body.TaskToRunShard, body.TaskToRunID),
	}
	switch err := datastore.Get(ctx, ttr); {
	case errors.Is(err, datastore.ErrNoSuchEntity):
		// This should not be happening normally.
		return srv.claimSkipped(ctx, ttr, "No such task")
	case err != nil:
		logging.Errorf(ctx, "Datastore error fetching TaskToRun: %s", err)
		return nil, status.Errorf(codes.Internal, "datastore error fetching TaskToRun")
	}

	// Check the bot actually matches the task dimensions. This is a security
	// check. A bot must not be able to claim tasks it doesn't match (in
	// particular in "pool" and/or "id" dimensions) even if the bot knows their
	// IDs.
	filter, err := model.NewFilterFromTaskDimensions(ttr.Dimensions)
	if err != nil {
		return srv.claimSkipped(ctx, ttr, "Task dimensions are unexpectedly broken: %s", err)
	}
	if !filter.MatchesBot(r.Dimensions) {
		return srv.claimSkipped(ctx, ttr, "Dimensions mismatch")
	}

	// Two bots should not be able to collide in their claim IDs.
	claimID := fmt.Sprintf("%s:%s", r.Session.BotId, body.ClaimID)

	// Check the task wasn't claimed yet or was already claimed by us. This will
	// be rechecked under transaction in ClaimOp.ClaimTxn.
	if !ttr.IsReapable() {
		switch existing := ttr.ClaimID.Get(); {
		case existing == "":
			return srv.claimSkipped(ctx, ttr, "The task slice has expired")
		case existing != claimID:
			return srv.claimSkipped(ctx, ttr, "Already claimed by %q", existing)
		}
		// This task was already claimed by us and this is a retry. Just return task
		// details, no need to run any transactions.
		//
		// The error must be checked before the nil-ness of details:
		// fetchClaimDetails returns nil details alongside an error, and such
		// failures must surface as errors rather than as a "skip" that would
		// make the bot drop a task it has already claimed.
		details, err := srv.fetchClaimDetails(ctx, ttr, r)
		switch {
		case err != nil:
			logging.Errorf(ctx, "Error fetching task details: %s", err)
			return nil, status.Errorf(codes.Internal, "datastore error fetching task details")
		case details == nil:
			// This should not be happening.
			return srv.claimSkipped(ctx, ttr, "No such task")
		}
		return srv.claimTask(ctx, details, r)
	}

	// Fetch task details outside of the transaction. As above, check the error
	// first since details are nil when an error is returned.
	details, err := srv.fetchClaimDetails(ctx, ttr, r)
	switch {
	case err != nil:
		logging.Errorf(ctx, "Error fetching task details: %s", err)
		return nil, status.Errorf(codes.Internal, "datastore error fetching task details")
	case details == nil:
		// This should not be happening.
		return srv.claimSkipped(ctx, ttr, "No such task")
	}

	eventType := model.BotEventTask
	if details.req.IsTerminate() {
		eventType = model.BotEventTerminate
	}

	// Transactionally claim the task and assign it to the bot. The transaction
	// happens as part of the botinfo.Update operation that assigns the task ID
	// to the bot.
	var outcome *tasks.ClaimOpOutcome
	update := &botinfo.Update{
		BotID:         r.Session.BotId,
		EventType:     eventType,
		EventDedupKey: body.ClaimID,
		TasksManager:  srv.tasksManager,
		Prepare: func(ctx context.Context, bot *model.BotInfo) (*botinfo.PrepareOutcome, error) {
			if bot == nil {
				return nil, errors.New("unexpectedly missing BotInfo entity")
			}
			// Use a closure-local error to avoid clobbering the `err` variable of
			// the enclosing function. Only `outcome` is meant to escape.
			var err error
			outcome, err = srv.tasksManager.ClaimTxn(ctx, &tasks.ClaimOp{
				Request:             details.req,
				TaskToRunKey:        ttr.Key,
				ClaimID:             claimID,
				BotDimensions:       r.Dimensions.ToMap(),
				BotVersion:          bot.Version,
				BotLogsCloudProject: r.Session.BotConfig.LogsCloudProject,
				BotIdleSince:        bot.IdleSince.Get(),
				BotOwners:           r.BotOwners,
			})
			if err != nil {
				return nil, err
			}
			return &botinfo.PrepareOutcome{
				Proceed: outcome.Claimed,
			}, nil
		},
		State: state,
		CallInfo: botCallInfo(ctx, &botinfo.CallInfo{
			SessionID: r.Session.SessionId,
		}),
		TaskInfo: &botinfo.TaskInfo{
			TaskID:    model.RequestKeyToTaskID(reqKey, model.AsRunResult),
			TaskName:  details.req.Name,
			TaskFlags: details.req.TaskFlags(),
		},
	}
	if err := srv.submitUpdate(ctx, update); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to claim the task: %s", err)
	}

	// NOTE(review): a successful update is expected to have invoked Prepare and
	// thus populated the outcome. Guard anyway so an unexpected code path
	// results in an error response instead of a nil dereference.
	if outcome == nil {
		logging.Errorf(ctx, "Claim transaction finished without an outcome")
		return nil, status.Errorf(codes.Internal, "failed to claim the task: no claim outcome")
	}

	if outcome.Unavailable != "" {
		// Pass the reason as an argument, not as a format string: it is not a
		// constant and may contain '%' characters.
		return srv.claimSkipped(ctx, ttr, "%s", outcome.Unavailable)
	}
	return srv.claimTask(ctx, details, r)
}

// claimDetails are fetched before hitting heavy transactions.
//
// For termination tasks only `req` is populated.
type claimDetails struct {
	// req is the task request being claimed.
	req *model.TaskRequest
	// slice is the index of the task slice (in req.TaskSlices) being claimed.
	slice int
	// secret is the task's secret bytes, if the slice has any.
	secret []byte
	// caches are named caches requested by the slice, with size hints.
	caches []TaskCache
	// settings is the service settings config used to build the manifest.
	settings *configpb.SettingsCfg
}

// fetchClaimDetails loads everything needed to describe the claimed task
// slice to the bot: the task request, its secret bytes, named cache size hints
// and the service settings.
//
// Returns (nil, nil) if there's no such task anymore for whatever reason (the
// TaskRequest entity is gone or the TaskToRun points to a non-existing slice).
// Returns (nil, err) if any of the fetches failed. For termination tasks only
// the request itself is populated in the result.
func (srv *BotAPIServer) fetchClaimDetails(ctx context.Context, ttr *model.TaskToRun, r *botsrv.Request) (*claimDetails, error) {
	req, err := model.FetchTaskRequest(ctx, ttr.TaskRequestKey())
	if err != nil {
		if errors.Is(err, datastore.ErrNoSuchEntity) {
			return nil, nil
		}
		return nil, errors.Fmt("fetching TaskRequest: %w", err)
	}

	// Termination tasks carry no slice properties worth fetching.
	if req.IsTerminate() {
		return &claimDetails{req: req}, nil
	}

	sliceIdx := ttr.TaskSliceIndex()
	if sliceIdx >= len(req.TaskSlices) {
		// This should not be happening.
		logging.Errorf(ctx,
			"TaskToRun references non-existing slice %d of %s, treating as a missing task",
			sliceIdx,
			model.RequestKeyToTaskID(ttr.TaskRequestKey(), model.AsRequest),
		)
		return nil, nil
	}
	props := &req.TaskSlices[sliceIdx].Properties

	// Load the task secret, tolerating its absence.
	var secret []byte
	if props.HasSecretBytes {
		ent := &model.SecretBytes{Key: model.SecretBytesKey(ctx, req.Key)}
		if err := datastore.Get(ctx, ent); err == nil {
			secret = ent.SecretBytes
		} else if errors.Is(err, datastore.ErrNoSuchEntity) {
			// This should not be happening, but just carry on without the secret.
			logging.Errorf(ctx, "SecretBytes for %s is missing, proceeding anyway",
				model.RequestKeyToTaskID(req.Key, model.AsRequest))
		} else {
			return nil, errors.Fmt("fetching SecretBytes: %w", err)
		}
	}

	// Resolve size hints for the named caches. This is skipped when the request
	// has no pool.
	var caches []TaskCache
	if requested := props.Caches; len(requested) != 0 && req.Pool() != "" {
		names := make([]string, 0, len(requested))
		for i := range requested {
			names = append(names, requested[i].Name)
		}
		osFamily := model.OSFamily(r.Dimensions.DimensionValues("os"))
		hints, err := model.FetchNamedCacheSizeHints(ctx, req.Pool(), osFamily, names)
		if err != nil {
			return nil, errors.Fmt("fetching named caches hints: %w", err)
		}
		caches = make([]TaskCache, 0, len(requested))
		for i := range requested {
			caches = append(caches, TaskCache{
				Name: requested[i].Name,
				Path: requested[i].Path,
				Hint: fmt.Sprintf("%d", hints[i]),
			})
		}
	}

	// Use a config at least as fresh as the one the bot has last seen.
	cfg, err := srv.cfg.FreshEnough(ctx, r.Session.LastSeenConfig.AsTime())
	if err != nil {
		return nil, errors.Fmt("fetching service config: %w", err)
	}

	out := &claimDetails{
		req:      req,
		slice:    sliceIdx,
		secret:   secret,
		caches:   caches,
		settings: cfg.Settings(),
	}
	return out, nil
}

// claimSkipped builds the ClaimSkip response and logs why the task is skipped.
//
// `reason` is a fmt-style format string and `args` are its arguments. The
// formatted message is both logged (as a warning) and returned to the bot in
// the Reason field. The returned error is always nil.
func (srv *BotAPIServer) claimSkipped(ctx context.Context, ttr *model.TaskToRun, reason string, args ...any) (botsrv.Response, error) {
	msg := fmt.Sprintf(reason, args...)
	runID := model.RequestKeyToTaskID(ttr.TaskRequestKey(), model.AsRunResult)
	logging.Warningf(ctx, "Skipping task %s (slice %d): %s", runID, ttr.TaskSliceIndex(), msg)

	resp := &ClaimResponse{Cmd: ClaimSkip}
	resp.Reason = msg
	return resp, nil
}

// claimTask builds the ClaimRun or ClaimTerminate response and logs it.
//
// For termination tasks only the command and the task ID are returned. For
// regular tasks the bot session in `r` is updated in place (bot config expiry,
// debug info and session expiry), re-serialized and returned together with the
// full task manifest. The bot config expiry is set to be long enough for the
// task to run to completion.
//
// Returns an Internal gRPC error if the session can't be serialized.
func (srv *BotAPIServer) claimTask(ctx context.Context, d *claimDetails, r *botsrv.Request) (botsrv.Response, error) {
	runID := model.RequestKeyToTaskID(d.req.Key, model.AsRunResult)

	if d.req.IsTerminate() {
		logging.Infof(ctx, "Claimed termination task %s", runID)
		resp := &ClaimResponse{Cmd: ClaimTerminate, TaskID: runID}
		return resp, nil
	}

	props := &d.req.TaskSlices[d.slice].Properties

	// Bump the config expiration in the session to be long enough to outlive the
	// task (with some fudge factor). This would allow the bot to finish the task
	// even if it is removed from the configs midway through the execution.
	//
	// TODO: Add a mechanism that ensures this expiry extension can't be abused
	// by a malicious bot that was removed from config. Otherwise it can keep
	// running tasks back-to-back forever, continuously bumping config expiry.
	budgetSec := props.ExecutionTimeoutSecs + props.GracePeriodSecs + 300
	configDeadline := clock.Now(ctx).Add(time.Duration(budgetSec) * time.Second)
	r.Session.BotConfig.Expiry = timestamppb.New(configDeadline)
	r.Session.DebugInfo = botsession.DebugInfo(ctx, srv.version)
	r.Session.Expiry = timestamppb.New(clock.Now(ctx).Add(botsession.Expiry))
	sessionBlob, err := botsession.Marshal(r.Session, srv.hmacSecret)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "fail to marshal session proto: %s", err)
	}

	// Figure out the ResultDB hostname. It is needed only if the task has a
	// ResultDB update token. The settings hold a URL, the bot wants a host.
	var rdbHost string
	if d.req.ResultDBUpdateToken != "" {
		if server := d.settings.GetResultdb().GetServer(); server != "" {
			parsed, err := url.Parse(server)
			if err != nil {
				logging.Errorf(ctx, "Ignoring invalid ResultDB URL %q", server)
			} else {
				rdbHost = parsed.Host
			}
		}
	}

	logging.Infof(ctx, "Claimed task %s (slice %d)", runID, d.slice)

	// Optional sections are included only when they carry something.
	manifest := &TaskManifest{
		TaskID: runID,

		Caches:          d.caches,
		CIPDInput:       pick(props.CIPDInput.IsPopulated(), &props.CIPDInput),
		Command:         props.Command,
		Containment:     pick(props.Containment.ContainmentType != 0, &props.Containment),
		Dimensions:      props.Dimensions,
		Env:             props.Env,
		EnvPrefixes:     props.EnvPrefixes,
		GracePeriodSecs: props.GracePeriodSecs,
		HardTimeoutSecs: props.ExecutionTimeoutSecs,
		IOTimeoutSecs:   props.IOTimeoutSecs,
		SecretBytes:     d.secret,
		CASInputRoot:    pick(props.CASInputRoot.CASInstance != "", &props.CASInputRoot),
		Outputs:         props.Outputs,
		Realm:           pick(d.req.Realm != "", &TaskRealm{Name: d.req.Realm}),
		RelativeCwd:     props.RelativeCwd,
		ResultDB: pick(rdbHost != "", &TaskResultDB{
			Hostname: rdbHost,
			CurrentInvocation: TaskInvocation{
				Name:        resultdb.InvocationName(srv.project, runID),
				UpdateToken: d.req.ResultDBUpdateToken,
			},
		}),
		ServiceAccounts: TaskServiceAccounts{
			System: TaskServiceAccount{
				ServiceAccount: valueOrNone(r.Session.BotConfig.SystemServiceAccount),
			},
			Task: TaskServiceAccount{
				ServiceAccount: valueOrNone(d.req.ServiceAccount),
			},
		},

		BotID:              r.Session.BotId,
		BotDimensions:      r.Dimensions.ToMap(),
		BotAuthenticatedAs: auth.CurrentIdentity(ctx),
	}

	resp := &ClaimResponse{
		Cmd:      ClaimRun,
		Session:  sessionBlob,
		TaskID:   runID,
		Manifest: manifest,
	}
	return resp, nil
}

// pick returns nil unless `yes` is true, in which case it returns `t` as is.
//
// It is used to conditionally populate optional pointer fields.
func pick[T any](yes bool, t *T) *T {
	if !yes {
		return nil
	}
	return t
}

// valueOrNone returns "none" for an empty string and `val` itself otherwise.
func valueOrNone(val string) string {
	if val == "" {
		return "none"
	}
	return val
}
